# one-shot variable importance in-sample
library(tidyverse)
library(data.table)
library(dtplyr)
library(tidytext)
library(quanteda)
library(distrom)
library(textir)

# Visualize random-forest variable importance as a horizontal bar chart.
#
# Args:
#   rf_mod:         fitted forest object. Only these fields are read:
#                   `$variable.importance` (named numeric vector) and, for the
#                   caption, `$r.squared` (classification = FALSE) or
#                   `$prediction.error` (classification = TRUE).
#   out:            plot title.
#   xvars:          label for the feature axis.
#   sub:            plot subtitle.
#   classification: single logical; selects which fit statistic is shown in
#                   the caption (see `rf_mod`).
#   shrink:         optional numeric threshold; when supplied, only variables
#                   with importance strictly greater than `shrink` are kept.
#   exclude:        character vector of variable names that are never
#                   plotted. The default reproduces the previously hard-coded
#                   exclusion of the "day_relative_1120" covariate.
#
# Returns:
#   A ggplot object (bars ordered by importance, axes flipped).
rf_to_viz <- function(rf_mod, out, xvars, sub,
                      classification = FALSE,
                      shrink = NULL,
                      exclude = "day_relative_1120") {

  # Coerce first so that 0/1 keep behaving like FALSE/TRUE, then fail loudly
  # on anything that is not a single, non-missing flag.
  classification <- as.logical(classification)
  if (length(classification) != 1L || is.na(classification)) {
    stop("`classification` must be a single TRUE/FALSE value", call. = FALSE)
  }

  # get variable importance into data frame (names become row names -> `var`)
  vimp_lpc <- data.frame(importance = rf_mod$variable.importance)
  vimp_lpc$var <- rownames(vimp_lpc)

  keep <- !(vimp_lpc$var %in% exclude)
  if (!is.null(shrink)) {
    # NA importances are dropped rather than turned into NA rows
    keep <- keep & !is.na(vimp_lpc$importance) & vimp_lpc$importance > shrink
  }
  vimp_lpc <- vimp_lpc[keep, , drop = FALSE]

  # The caption is the only thing that differs between the two model types.
  fit_caption <- if (classification) {
    paste0("OOB Prediction Error: ", round(rf_mod$prediction.error, 3))
  } else {
    paste0("OOB R-Squared: ", round(rf_mod$r.squared, 3))
  }

  ggplot(vimp_lpc, aes(x = reorder(var, importance), y = importance)) +
    geom_bar(stat = "identity", position = "dodge") +
    coord_flip() +
    ylab("Variable importance (permutation)") +
    xlab(xvars) +
    labs(title = paste0(out),
         subtitle = sub,
         caption = fit_caption) +
    # "none" replaces the deprecated logical form `guides(fill = F)`
    guides(fill = "none") +
    theme_bw()
}

# Read the tweet-level data as a data.table so the derived columns below can
# be added by reference.
covid.tweets.nrt <- data.table::fread(
  "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/covid_tweets_nrt_4120.csv"
)

# Derived columns, all assigned in one by-reference step:
#   week / month : calendar week and month of `date`
#   lowtext      : hashtags starting with 'china' generalized to
#                  'china_hashtag'. The pattern only matches a hashtag that is
#                  preceded by whitespace and has at least one word character
#                  after "#china".
covid.tweets.nrt[, `:=`(
  week = lubridate::week(date),
  month = lubridate::month(date),
  lowtext = stringr::str_replace_all(
    lowtext,
    "\\s(#china[\\w]+)",
    " china_hashtag"
  )
)]

# Everything downstream works on a tibble.
covid.tweets.nrt <- as_tibble(covid.tweets.nrt)

# Fix the RNG state before any modelling step.
set.seed(11111)

# One document per row: text comes from `lowtext`, document ids from `index`.
corp <- quanteda::corpus(
  covid.tweets.nrt,
  text_field = "lowtext",
  docid_field = "index"
)

message("Setting compounds...")

# Multi-word expressions that should be treated as a single token.
# Each entry is a whitespace-separated phrase; quanteda::phrase() turns the
# vector into a pattern object that is later handed to tokens_compound().
# Note: a phrase containing stop words (e.g. "cap and trade") can only match
# token sequences in which those stop words are still present.
compounds <- quanteda::phrase(
  c(
    # general politics / policy
    "white house", "fox news", "united states",
    "united nations",
    "social media", "climate change", "right wing",
    "green new deal", "nuclear power", "cap and trade", "clean coal",
    "gun violence", "national security",
    "tax cut", "cut tax", "cut taxes",
    "president trump", "american people", "trump administration",
    "medicare for all", "health care", "health insurance",
    "single payer",
    "assault weapon", "assault weapons", "semi automatic",
    "second amendment", "brady bill",
    # fixed typo: was "high capacity magainze", which could never match
    "high capacity magazine", "high capacity magazines",
    "bump stock", "bump stocks",
    "background check", "background checks",
    "build the wall", "wall funding", "birthright citizenship",
    "nuclear deal",
    "pro life", "pro choice", "anti choice",
    "born alive", "partial birth", "late term",
    "house democrats", "house republicans",
    "senate democrats", "senate republicans",
    "majority leader", "minority leader",
    "town hall", "donald trump", "social security",
    "law enforcement", "preexisting conditions",

    # covid compounds ("town hall" is already listed above)
    "world health organization", "centers for disease control",
    "supply chain", "shelter in place", "defense production act",
    "personal protective equipment", "laid off",
    "wish list", "op ed"
  )
)

message("Tokenizing...")

# tokenize/preprocess/remove extraneous twitter stuff
#
# Order of the steps matters:
#   1. tokenize, dropping punctuation, numbers and URLs
#   2. drop separator-only tokens and the literal token "via"
#   3. join the phrases in `compounds` into single tokens
#   4. remove stop words
#   5. stem, then build 1- to 3-grams
#
# Compounding (3) runs BEFORE stop-word removal (4). Several phrases in
# `compounds` contain stop words ("cap and trade", "medicare for all",
# "shelter in place", ...); once "and"/"for"/"in" have been removed those
# phrases can no longer be found in the token stream.
#
# NOTE(review): `remove_twitter` appears to be deprecated/ignored in
# quanteda >= 2.0 -- confirm against the installed version. It is kept so
# behavior on the version this script was written for does not change.
tok <- quanteda::tokens(corp,
                        remove_punct = TRUE,
                        remove_twitter = TRUE,
                        remove_numbers = TRUE,
                        remove_url = TRUE) %>%
  quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%

  # exact-token match: as a regex, "via" also removed every token that merely
  # contains it (e.g. "alleviate", "aviation", "viable")
  quanteda::tokens_remove("via", valuetype = "fixed") %>%

  # concatenate common bigrams/phrases while stop words are still in place
  quanteda::tokens_compound(pattern = compounds) %>%

  # remove stop words
  quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%

  quanteda::tokens_wordstem() %>%
  quanteda::tokens_ngrams(1:3)

# Document-feature matrix over every token produced above (lower-cased).
dfm.full <- tok %>%
  dfm(tolower = TRUE, verbose = TRUE)

# Drop features that are too rare or too common. Both bounds are document
# proportions: the lower bound corresponds to 100 documents, the upper bound
# to half of all documents.
message("Trimming dfm...")

dfm.lim <- dfm.full %>%
  dfm_trim(
    min_docfreq = 100 / nrow(covid.tweets.nrt),
    max_docfreq = .5,
    docfreq_type = "prop"
  )

message("Retained ", ncol(dfm.lim), " tokens for analysis...")

message("Prepping IV and DV...")

# Dense document x token count matrix; row names are the document ids.
token_counts <- as.matrix(dfm.lim)

# discard empty docs
# The check is done on the token counts ALONE, before the date covariate is
# attached: summing counts together with the day offset would keep empty
# documents whose offset is non-zero (and could drop non-empty ones whose
# counts happen to cancel a negative offset).
drops <- which(rowSums(token_counts) == 0)

if (length(drops) > 0) {
  token_counts <- token_counts[-drops, , drop = FALSE]
}

# Look the metadata up once, for the surviving documents only, so that the
# counts, the date covariate and the outcome always have the same rows in the
# same order.
doc_rows <- match(rownames(token_counts), covid.tweets.nrt$index)

# one-column tibble holding the day covariate
dates <- covid.tweets.nrt[doc_rows, "day_relative_1120"]

# make outcome vector
Y <- covid.tweets.nrt[doc_rows, "is_republican"] %>% pull("is_republican")

# Random-forest design: token counts plus the day covariate. Binding a matrix
# to a tibble yields a data frame, which is what ranger receives below.
X <- cbind(token_counts, dates)

datbind <- cbind(Y, X)

# model
# Probability forest for Y with permutation importance for every column of X.
rf1 <- ranger::ranger(data = datbind,
                      dependent.variable.name = "Y",
                      classification = TRUE,
                      probability = TRUE,
                      importance = "permutation",
                      verbose = TRUE)

# run model
# Multinomial inverse regression: token counts are the response, the outcome
# and the day covariate are the covariates. The counts are passed as the
# plain token matrix rather than by deleting the date column from X by name.
Y_mir <- cbind(Y, dates)
X_mir <- token_counts

mod1 <- dmr(covars = Y_mir, counts = X_mir, cl = NULL, verb = 2, gamma = 0, nlambda = 100)

message("Prepping data...")

# Coefficients of the fitted dmr model.
c1 <- coef(mod1)

# Transpose so that each row is one term, then convert to a data frame and
# expose the term (row name) as a regular column.
tmp <- t(c1) %>%
  as.matrix() %>%
  as.data.frame()
tmp[["word"]] <- rownames(tmp)

# Total frequency of every retained term, summed over the documents in
# dfm.lim, keyed by term.
xm <- data.frame(word = colnames(dfm.lim),
                 freq = colSums(dfm.lim))
xm[["word"]] <- as.character(xm[["word"]])

# Attach the frequencies to the coefficient table.
tmp <- left_join(tmp, xm, by = "word")

# Permutation importance of the random forest as a two-column data frame:
# `importance` holds the score, `var` the feature name (taken from the names
# of the importance vector, which become the row names).
vimp_lpc <- data.frame(importance = rf1$variable.importance)
vimp_lpc[["var"]] <- row.names(vimp_lpc)

# Figure S6: random-forest importance of each text feature, signed by the
# direction of its inverse-regression coefficient on Y.
#   Y > 0        -> "red"   (labelled Republican), bar drawn to the right
#   Y < 0        -> "blue"  (labelled Democrat),   bar drawn to the left
#   Y == 0 / NA  -> "black" (labelled No Direction), bar drawn to the right

# Tick positions of the importance axis; labels show absolute values because
# the sign only encodes the side a bar is drawn on.
importance_breaks <- seq(from = -.0075, to = .0075, by = .0025)

imp_plot_data <-
  vimp_lpc %>%
  left_join(tmp[, c("word", "Y", "freq")],
            by = c("var" = "word")) %>%
  # the day covariate is not a text feature
  filter(var != "day_relative_1120") %>%
  # keep only features with non-negligible importance
  filter(importance > .001) %>%
  mutate(fillcol = case_when(Y > 0 ~ "red",
                             Y < 0 ~ "blue",
                             Y == 0 | is.na(Y) ~ "black"),
         val = ifelse(fillcol == "blue", -1 * importance, importance))

imp_with_coefs <-
  ggplot(imp_plot_data,
         aes(x = reorder(var, val),
             y = val,
             fill = factor(fillcol, levels = c("blue", "red", "black")))) +
  geom_bar(stat = "identity",
           position = "dodge") +
  # colours are named so each level keeps its colour even when one of the
  # three groups is absent from the data
  scale_fill_manual(name = "Generally indicates author is a...",
                    breaks = c("blue", "red", "black"),
                    values = c(blue = "darkblue",
                               red = "darkred",
                               black = "black"),
                    labels = c("Democrat", "Republican", "No Direction")) +
  coord_flip() +
  scale_y_continuous(name = "Random Forest Permutation Importance",
                     breaks = importance_breaks,
                     labels = paste0(abs(importance_breaks))) +
  xlab("Features") +
  labs(title = "Most Important Text Features",
       subtitle = "Partisan association defined by sign of multinomial inverse regression coefficient",
       caption = paste0("OOB Prediction Error: ", round(rf1$prediction.error, 3))) +
  theme_bw() +
  theme(text = element_text(family = "serif"),
        plot.title = element_text(size = 24),
        plot.subtitle = element_text(size = 14),
        plot.caption = element_text(size = 12, face = "bold", hjust = 0),
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 12))

# Arguments are spelled out: the previous call only worked because `file`
# partially matched `filename` and the plot fell through to `plot`.
ggsave(filename = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/figure_s6.png",
       plot = imp_with_coefs,
       width = 12,
       height = 6)

# Persist the importance table, the coefficient table and the fitted forest.
save(vimp_lpc, tmp, rf1, file = "~/Dropbox/CHAMP-Net/coronavirus_paper/data_and_code/replication_file/output/covid_token_importance.RData")

